{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Simple Grad",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rnvTWs4W4Hea"
      },
      "source": [
        "# Simple Autograd\n",
        "\n",
        "This notebook walks through a self-contained implementation of reverse mode auto-differentiation. The intention is to make it easier to understand PyTorch's implementation of auto-diff and how TorchScript interacts with it without having to work through all the complexity that the real implementation contains.\n",
        "\n",
        "\n",
        "To get started, we import some helper functions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2Lf8JaEK4MHJ"
      },
      "source": [
        "import torch\n",
        "from typing import List, NamedTuple, Callable, Dict, Optional\n",
        "\n",
        "_name: int = 0\n",
        "def fresh_name() -> str:\n",
        "    \"\"\" create a new unique name for a variable: v0, v1, v2 \"\"\"\n",
        "    global _name\n",
        "    r = f'v{_name}'\n",
        "    _name += 1\n",
        "    return r\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "erjC686T4S4c"
      },
      "source": [
        "To make it possible to fully understand, this system does not rely on PyTorch's autograd at all. It only uses the Tensor object to do compute. We add our own Variable class to track the gradients of computation, and a `grad` function compute gradients.\n",
        "\n",
        "\n",
        "Similar to PyTorch, we use tape-based reverse mode auto-differentiation to\n",
        "compute the gradient. For some scalar loss `l`, we will compute the value `dl/dX` for \n",
        "_every_ value `X` computed in the program (`l` is always a scalar, but the `X`s can be tensors).\n",
        "We do this by starting with `dl/dl == 1`, and use the partial derivatives plus \n",
        "the chain rule to propagate the values backward,\n",
        "e.g. `dl/dx * dx/dy = dl/dy`.\n",
        "\n",
        "https://sidsite.com/posts/autodiff/ might be a good place to start if you \n",
        "haven't seen reverse mode auto-diff before.\n",
        "\n",
        "For the purpose of this example, we primarily use point-wise tensor operators like `+` to keep the partial derivatives simple.\n",
        "\n",
        "\n",
        "# The Implementation\n",
        "\n",
        "Variable is a wrapper around Tensor that tracks the compute.\n",
        " Each variable has a globally unique name so that we can track the gradient\n",
        "for this Variable in a dictionary. For ease of understanding,\n",
        "we sometimes provide this name as an argument. Otherwise, we \n",
        "generate a fresh temporary each time.\n",
        "        "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sfaIcdiXEZpO"
      },
      "source": [
        "class Variable:\n",
        "    def __init__(self, value : torch.Tensor, name: str=None):\n",
        "        self.value = value\n",
        "        self.name = name or fresh_name()\n",
        "\n",
        "    # We need to start with some tensors whose values were not computed\n",
        "    # inside the autograd. This function constructs leaf nodes. \n",
        "    @staticmethod\n",
        "    def constant(value: torch.Tensor, name: str=None):\n",
        "        r = Variable(value, name)\n",
        "        print(f'{r.name} = {value}')\n",
        "        return r\n",
        "\n",
        "    def __repr__(self):\n",
        "        return repr(self.value)\n",
        "\n",
        "\n",
        "    # This performs a pointwise multiplication of a Variable, tracking gradients\n",
        "    def __mul__(self, rhs: 'Variable') -> 'Variable':\n",
        "        # defined later in the notebook\n",
        "        return operator_mul(self, rhs)\n",
        "\n",
        "    def __add__(self, rhs: 'Variable') -> 'Variable':\n",
        "        return operator_add(self, rhs)\n",
        "            \n",
        "    def sum(self, name: Optional[str]=None) -> 'Variable':\n",
        "        return operator_sum(self, name)\n",
        "    \n",
        "    def expand(self, sizes: List[int]) -> 'Variable':\n",
        "        return operator_expand(self, sizes)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o4gBKqAGEoHD"
      },
      "source": [
        "We need to keep track of all the computation so we can apply the\n",
        "chain rule backward. A tape entry will help is do this."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y5QC7spUEuVV"
      },
      "source": [
        "class TapeEntry(NamedTuple):\n",
        "    # names of the inputs to the original computation\n",
        "    inputs : List[str]\n",
        "    # names of the outputs of the original computation\n",
        "    outputs: List[str]\n",
        "    # apply chain rule\n",
        "    propagate: 'Callable[List[Variable], List[Variable]]'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "byeB7Cg_E8Kz"
      },
      "source": [
        "The `inputs` and `outputs` are the unique names of the Variables that are inputs and outputs of the _original_ computation.  `propagate` is a closure that propagates the gradient of the outputs of this function to the inputs using the chain rule. This is specific to each leaf operator. Its inputs are `dL/dOutputs`, and its outputs are `dL/dInputs`.  The tape is a just a list of accumulated entries recording all compute. We provide a way to reset it so we can run multiple examples."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XmfuXusUFVL3"
      },
      "source": [
        "gradient_tape : List[TapeEntry] = []\n",
        "\n",
        "def reset_tape():\n",
        "  gradient_tape.clear()\n",
        "  global _name\n",
        "  _name = 0 # reset variable names too to keep them small.\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hb--WbxnFvec"
      },
      "source": [
        "Now let's look at how an operator is defined. First we calculate the forward result and create a new Variable to represent it. Then we define the `propagate` closure, which uses the chain rule to backprop the gradient."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qJYuZnYqGZrt"
      },
      "source": [
        "def operator_mul(self : Variable, rhs: Variable) -> Variable:\n",
        "    if isinstance(rhs, float) and rhs == 1.0:\n",
        "        # peephole optimization\n",
        "        return self\n",
        "\n",
        "    # define forward\n",
        "    r = Variable(self.value * rhs.value)\n",
        "    print(f'{r.name} = {self.name} * {rhs.name}')\n",
        "\n",
        "    # record what the inputs and outputs of the op were\n",
        "    inputs = [self.name, rhs.name]\n",
        "    outputs = [r.name]\n",
        "\n",
        "    # define backprop\n",
        "    def propagate(dL_doutputs: List[Variable]):\n",
        "        dL_dr, = dL_doutputs\n",
        "    \n",
        "        dr_dself = rhs # partial derivative of r = self*rhs\n",
        "        dr_drhs = self # partial derivative of r = self*rhs\n",
        "\n",
        "        # chain rule propagation from outputs to inputs of multiply\n",
        "        dL_dself = dL_dr * dr_dself\n",
        "        dL_drhs = dL_dr * dr_drhs\n",
        "        dL_dinputs = [dL_dself, dL_drhs] \n",
        "        return dL_dinputs\n",
        "    # finally, we record the compute we did on the tape\n",
        "    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))\n",
        "    return r"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bn69AKteGaWN"
      },
      "source": [
        "  Notice how both `rhs` and `self` are captured by this closure.\n",
        "  Their values have to be saved for the backward pass.\n",
        "  PyTorch does something similar, but because PyTorch allows for\n",
        "  mutable tensors, it has additional logic to make sure these captured\n",
        "  variables are not mutated.\n",
        "\n",
        "  We'll define the other operators later. Let's look at how we can define a `grad` function that puts these pieces together. `grad` calculates the gradient of `L` with respect to `desired_results`. We first calculate the gradient of `L` with respect to _all_ computed values and then just extract `desired_results` from them. Real systems do more pruning ahead of time to make sure we are not computing unused values.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_iWNs9YfH8KM"
      },
      "source": [
        "def grad(L, desired_results: List[Variable]) -> List[Variable]:\n",
        "    # this map holds dL/dX for all values X\n",
        "    dL_d : Dict[str, Variable] = {}\n",
        "    # It starts by initializing the 'seed' dL/dL, which is 1\n",
        "    dL_d[L.name] = Variable(torch.ones(()))\n",
        "    print(f'd{L.name} ------------------------')\n",
        "\n",
        "    # look up dL_dentries. If a variable is never used to compute the loss,\n",
        "    # we consider its gradient None, see the note below about zeros for more information.\n",
        "    def gather_grad(entries: List[str]):\n",
        "        return [dL_d[entry] if entry in dL_d else None for entry in entries]\n",
        "\n",
        "    # propagate the gradient information backward\n",
        "    for entry in reversed(gradient_tape):\n",
        "        dL_doutputs = gather_grad(entry.outputs)\n",
        "        if all(dL_doutput is None for dL_doutput in dL_doutputs):\n",
        "            # optimize for the case where some gradient pathways are zero. See\n",
        "            # The note below for more details.\n",
        "            continue\n",
        "\n",
        "        # perform chain rule propagation specific to each compute\n",
        "        dL_dinputs = entry.propagate(dL_doutputs)\n",
        "\n",
        "        # Accululate the gradient produced for each input.\n",
        "        # Each use of a variable produces some gradient dL_dinput for that \n",
        "        # use. The multivariate chain rule tells us it is safe to sum \n",
        "        # all the contributions together.\n",
        "        for input, dL_dinput in zip(entry.inputs, dL_dinputs):\n",
        "            if input not in dL_d:\n",
        "                dL_d[input] = dL_dinput\n",
        "            else:\n",
        "                dL_d[input] += dL_dinput\n",
        "\n",
        "    # print some information to understand the values of each intermediate \n",
        "    for name, value in dL_d.items():\n",
        "        print(f'd{L.name}_d{name} = {value.name}')\n",
        "    print(f'------------------------')\n",
        "\n",
        "    return gather_grad(desired.name for desired in desired_results)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
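    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the same mechanics without `Variable`, here is a minimal scalar-only sketch in plain Python. The names `Entry`, `tape`, `values`, `mul`, and `add` are hypothetical and not part of this notebook. Because each `propagate` returns floats rather than `Variable`s, this version records nothing during the backward pass and so cannot take gradients of gradients; it also assumes every recorded output contributes to `L`, so it has no `None` handling.\n",
        "\n",
        "```python\n",
        "from typing import Callable, Dict, List, NamedTuple\n",
        "\n",
        "class Entry(NamedTuple):\n",
        "    inputs: List[str]\n",
        "    outputs: List[str]\n",
        "    propagate: Callable[[List[float]], List[float]]\n",
        "\n",
        "tape: List[Entry] = []\n",
        "values: Dict[str, float] = {}\n",
        "\n",
        "def mul(x: str, y: str, out: str) -> None:\n",
        "    xv, yv = values[x], values[y]  # captured for the backward pass\n",
        "    values[out] = xv * yv\n",
        "    def propagate(dL_douts: List[float]) -> List[float]:\n",
        "        (dL_dout,) = dL_douts\n",
        "        return [dL_dout * yv, dL_dout * xv]\n",
        "    tape.append(Entry([x, y], [out], propagate))\n",
        "\n",
        "def add(x: str, y: str, out: str) -> None:\n",
        "    values[out] = values[x] + values[y]\n",
        "    def propagate(dL_douts: List[float]) -> List[float]:\n",
        "        (dL_dout,) = dL_douts\n",
        "        return [dL_dout, dL_dout]\n",
        "    tape.append(Entry([x, y], [out], propagate))\n",
        "\n",
        "# forward pass: t = a + b; L = t * b\n",
        "values['a'], values['b'] = 2.0, 5.0\n",
        "add('a', 'b', 't')\n",
        "mul('t', 'b', 'L')\n",
        "\n",
        "# backward pass: seed dL/dL = 1, walk the tape in reverse, sum per-use gradients\n",
        "dL_d: Dict[str, float] = {'L': 1.0}\n",
        "for entry in reversed(tape):\n",
        "    dL_douts = [dL_d[name] for name in entry.outputs]\n",
        "    for name, g in zip(entry.inputs, entry.propagate(dL_douts)):\n",
        "        dL_d[name] = dL_d.get(name, 0.0) + g\n",
        "\n",
        "print(dL_d['a'], dL_d['b'])  # dL/da = b, dL/db = a + 2*b\n",
        "```"
      ]
    },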
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DKF1A2XFKOtj"
      },
      "source": [
        "# Some more operators\n",
        "\n",
        "We'll use these in our examples. Their implementation is very similar to `operator_mul`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tRqwgR64KWLb"
      },
      "source": [
        "def operator_add(self : Variable, rhs: Variable) -> Variable:\n",
        "    # Add follows a similar pattern to Mul, but it doesn't end up\n",
        "    # capturing any variables.\n",
        "    r = Variable(self.value + rhs.value)\n",
        "    print(f'{r.name} = {self.name} + {rhs.name}')\n",
        "    def propagate(dL_doutputs: List[Variable]):\n",
        "        dL_dr, = dL_doutputs\n",
        "        dr_dself = 1.0\n",
        "        dr_drhs = 1.0\n",
        "        dL_dself = dL_dr * dr_dself\n",
        "        dL_drhs = dL_dr * dr_drhs\n",
        "        return [dL_dself, dL_drhs]\n",
        "    gradient_tape.append(TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate))\n",
        "    return r\n",
        "\n",
        "# sum is used to turn our matrices into a single scalar to get a loss.\n",
        "# expand is the backward of sum, so it is added to make sure our Variable\n",
        "# is closed under differentiation. Both have rules similar to mul above.\n",
        "\n",
        "def operator_sum(self: Variable, name: Optional[str]) -> 'Variable':\n",
        "    r = Variable(torch.sum(self.value), name=name)\n",
        "    print(f'{r.name} = {self.name}.sum()')\n",
        "    def propagate(dL_doutputs: List[Variable]):\n",
        "        dL_dr, = dL_doutputs\n",
        "        size = self.value.size()\n",
        "        return [dL_dr.expand(*size)]\n",
        "    gradient_tape.append(TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate))\n",
        "    return r\n",
        "\n",
        "\n",
        "def operator_expand(self: Variable, sizes: List[int]) -> 'Variable':\n",
        "    assert(self.value.dim() == 0) # only works for scalars\n",
        "    r = Variable(self.value.expand(sizes))\n",
        "    print(f'{r.name} = {self.name}.expand({sizes})')\n",
        "    def propagate(dL_doutputs: List[Variable]):\n",
        "        dL_dr, = dL_doutputs\n",
        "        return [dL_dr.sum()]\n",
        "    gradient_tape.append(TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate))\n",
        "    return r"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i1TS_fbDBQry"
      },
      "source": [
        "# Using `grad`\n",
        "Let's use the implementation to calculate some gradients"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "njvtatdLDrBz",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 272
        },
        "outputId": "0002f41e-4ce1-411f-f6bc-5c66f424382a"
      },
      "source": [
        "a_global, b_global = torch.rand(4), torch.rand(4)\n",
        "\n",
        "def simple(a, b):\n",
        "    t = a + b\n",
        "    return t * b\n",
        "\n",
        "reset_tape() # reset any compute from other cells\n",
        "a = Variable.constant(a_global, name='a')\n",
        "b = Variable.constant(b_global, name='b')\n",
        "loss = simple(a, b)\n",
        "da, db = grad(loss, [a, b])\n",
        "print(\"da\", da)\n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
            "b = tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
            "v0 = a + b\n",
            "v1 = v0 * b\n",
            "dv1 ------------------------\n",
            "v3 = v2 * b\n",
            "v4 = v2 * v0\n",
            "v5 = v4 + v3\n",
            "dv1_dv1 = v2\n",
            "dv1_dv0 = v3\n",
            "dv1_db = v5\n",
            "dv1_da = v3\n",
            "------------------------\n",
            "da tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
            "db tensor([0.7719, 1.4249, 1.6311, 0.6567])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dszwikD7ENj5"
      },
      "source": [
        "# Zero Gradients\n",
        "\n",
        "An interesting case to look at is when the gradient is zero."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PkmEfLA9EV46",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        },
        "outputId": "23de0b78-cd3b-4f4d-f7ca-5dc0aec42b3a"
      },
      "source": [
        "reset_tape()\n",
        "loss = a*a\n",
        "da, db = grad(loss, [a, b])\n",
        "print(\"da\", da)\n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "v0 = a * a\n",
            "dv0 ------------------------\n",
            "v2 = v1 * a\n",
            "v3 = v1 * a\n",
            "v4 = v2 + v3\n",
            "dv0_dv0 = v1\n",
            "dv0_da = v4\n",
            "------------------------\n",
            "da tensor([0.9209, 0.8121, 1.8843, 0.7893])\n",
            "db None\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bwfRHXd3ElEF"
      },
      "source": [
        "Notice that `db` has the value `None`. Another perhaps more mathematically appropriate choice would be to return a 4 element tensor of zeros because a value that does not contribute to the loss will have a gradient of zero. So why do we use `None` instead? The reason is because we want to be able to quickly check that a gradient value is zero, so that we can skip `propgate` functions that involve it in `grad`:\n",
        "\n",
        "```\n",
        "if all(dL_doutput is None for dL_doutput in dL_doutputs):\n",
        "    # optimize for the case where some gradient pathways are zero. See\n",
        "    # The note below for more details.\n",
        "    continue\n",
        "```\n",
        "\n",
        "How does this skipping optimization work? Each propagate function is applying the chain rule.\n",
        "In the general case where there is a vector of inputs and vector\n",
        "of outputs to the function, the jacobean `J` represents the pairwise\n",
        "partial derivatives from each input to each output (`dinput_i/d_output_j`) in matrix form.\n",
        "The multiplication `v*J` (equivalently `J^t*v` if you treat `v` as a column vector) propagates the chain \n",
        "rule backward. This is why propagate is sometimes called the\n",
        "vector-Jacobean product, or `vjp` (and also why forward autodiff uses the Jacobean-vector product).\n",
        "\n",
        "In practice, we do not construct the `J` matrix, because it often\n",
        "has a lot of structure in it. For instance, in pointwise operations,\n",
        "it is a diagonal matrix (input of vector `i` affects only the output of vector `i`).  Constructing it would create `N^2` entries when we only have `N` non-zeros.\n",
        "\n",
        "However, we know that propgate always computes a matrix product\n",
        "against `J`. One important property is if `v` is 0, we know from\n",
        "the fact that matrix multiplication is a linear operator, that `v*J`\n",
        "is also 0. This is what the the `if`-statement is saying. If all the \n",
        "input derivatives are 0, we know the outputs are 0, even without\n",
        "running propagate. This property is important in autograd as we often\n",
        "do more compute that is not related to the loss, and do not\n",
        "want to waste time computing zero gradients for it.\n",
        "\n",
        "This would be more expensive to check if we have to check that each element of a matrix was zero. So we use `None` in grad to represent a value _known_ to be full of only zeros, making the check constant time. PyTorch's autograd does the exact same check. For historical reasons it uses undefined tensors (`at::Tensor()`) in C++ to represent these known-to-be-zero tensors. This has implications for when we generate gradients for aggregate operators as we will see later. When working with the PyTorch autograd, you should keep in mind that undefined tensors are always used to represent these known-to-be-zero values. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D-4kNZj7GWzR"
      },
      "source": [
        "# Gradients of Gradients\n",
        "\n",
        "Notice how the definition of `propagate` works on `Variables` not `Tensors`. This is so that it can calculate the gradient of some other gradient. Just think of the first gradient computation like any other compute you can do. There is no reason why you can't take a gradient of that compute as well. As a concrete example lets look at this code:\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_-Vjp9w4G8qo",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 884
        },
        "outputId": "698a70ad-01c7-4578-c089-9995a8bf3389"
      },
      "source": [
        "def run_gradients(my_fn, second_loss=True):\n",
        "    reset_tape()\n",
        "    a = Variable.constant(a_global, name='a')\n",
        "    b = Variable.constant(b_global, name='b')\n",
        "\n",
        "    # our first loss\n",
        "    L0 = (my_fn(a, b)).sum(name='L0')\n",
        "\n",
        "    # compute derivatives of our inputs\n",
        "    dL0_da, dL0_db = grad(L0, [a, b])\n",
        "    if not second_loss:\n",
        "      return dL0_da, dL0_db\n",
        "\n",
        "    # now lets compute the L2 norm of our derivatives\n",
        "    L1 = (dL0_da*dL0_da + dL0_db*dL0_db).sum(name='L1')\n",
        "\n",
        "    # and take the gradient of that.\n",
        "    # notice there are two losses involved.\n",
        "    dL1_da, dL1_db = grad(L1, [a, b])\n",
        "    return dL1_da, dL1_db\n",
        "\n",
        "da, db = run_gradients(simple)\n",
        "print(\"da\", da)\n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "v0 = a + b\n",
            "v1 = v0 * b\n",
            "L0 = v1.sum()\n",
            "dL0 ------------------------\n",
            "v3 = v2.expand(4)\n",
            "v4 = v3 * b\n",
            "v5 = v3 * v0\n",
            "v6 = v5 + v4\n",
            "dL0_dL0 = v2\n",
            "dL0_dv1 = v3\n",
            "dL0_dv0 = v4\n",
            "dL0_db = v6\n",
            "dL0_da = v4\n",
            "------------------------\n",
            "v7 = v4 * v4\n",
            "v8 = v6 * v6\n",
            "v9 = v7 + v8\n",
            "L1 = v9.sum()\n",
            "dL1 ------------------------\n",
            "v11 = v10.expand(4)\n",
            "v12 = v11 * v6\n",
            "v13 = v11 * v6\n",
            "v14 = v12 + v13\n",
            "v15 = v11 * v4\n",
            "v16 = v11 * v4\n",
            "v17 = v15 + v16\n",
            "v18 = v17 + v14\n",
            "v19 = v14 * v0\n",
            "v20 = v14 * v3\n",
            "v21 = v18 * b\n",
            "v22 = v18 * v3\n",
            "v23 = v19 + v21\n",
            "v24 = v23.sum()\n",
            "v25 = v22 + v20\n",
            "dL1_dL1 = v10\n",
            "dL1_dv9 = v11\n",
            "dL1_dv7 = v11\n",
            "dL1_dv8 = v11\n",
            "dL1_dv6 = v14\n",
            "dL1_dv4 = v18\n",
            "dL1_dv5 = v14\n",
            "dL1_dv3 = v23\n",
            "dL1_dv0 = v20\n",
            "dL1_db = v25\n",
            "dL1_dv2 = v24\n",
            "dL1_da = v20\n",
            "------------------------\n",
            "da tensor([1.2611, 2.1304, 5.8394, 3.3869])\n",
            "db tensor([ 2.6923,  4.9201, 13.6563,  8.0727])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pxdBcwoPHvRE"
      },
      "source": [
        "Notice how the `gradient_tape` just keeps accumulating more entries as we run `grad` twice. This is because in the second call to `grad` we still have to consider all the pathways through which the gradient flows all the way from `L1` back to the inputs `a` and `b`.  One implication is that the entries that are run in the first call to `grad` actually get run _again_ in the second call to `grad`.  In practice this means that if you append a `propagate` function to the tape in a gradient-of-gradient scenario, you should expect it to run multiple times! If a single gradient compute is \"forward, backward\", then a gradient of gradient compute could be thought of as \"forward-part-0, backward-part-0, foward-part-1, backward-part-1, backward-part-0 (again)\".\n",
        "\n",
        "Issues with how autograd functions behave often _only_ appear when considering higher order gradients so it is important to test changes on these cases. We'll see an example later."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XVos1LB_JgA5"
      },
      "source": [
        "# Rules of thumb for Autograd\n",
        "\n",
        "## Every use of a Variable generates a gradient specific to that use\n",
        "\n",
        "If you use a temporary variable `t` in two different subsequent computations, each _use_ of that value will have a gradient associated with it from the using operator. The multivariate chain rule tells us we can sum these gradients to get the overall contribute of `t`. We always have to account for all uses of a variable. If we forget about one, we will calculate the wrong value.\n",
        "\n",
        "## Inputs become outputs, outputs become inputs, reads become writes, writes become reads\n",
        "\n",
        "When we record a `TapeEntry` we also record the inputs and outputs of the compute _from the perspective of the forward pass_. The inputs/outputs of the propgate function in the backward pass are _flipped_. You get `dL/doutputs` and you produce `dL/dinputs`. It is easy to get confused by names like input or output. You have to keep in mind what they are relative to. A corrolary here occurs at the level of compute. Because every read of a value in a matrix produces a gradient, it implies that in the backward pass we will be computing (and writing) a value for every read in forward. For instance, the `sum` operator reads an entire matrix and produces one value. So its reverse must be an operator that reads one value and writes an entire matrix.  Indeed, the backward of `sum` is `expand`, which does precisely that.\n",
        "\n",
        "## Each call to grad produces gradients for a different loss\n",
        "\n",
        "When you call `grad(l, [a,b])` you are computing a set of gradients `dl_da`, `dl_db`. A subsquent call to `grad` will use a different loss, and potentially care about different inputs. If you abbreviate the loss, e.g. by saying `da`, you better be sure there aren't additional losses or you will quickly get confused. Gradient-of-gradient, or higher-order gradient, just means that we are computing some loss that was based on the gradients of another loss. There are an infinite number of calculations that compute gradients. There isn't a single \"grad-of-grad\" compute.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SI-bjH-1NHNx"
      },
      "source": [
        "# Creating Aggregate Operations\n",
        "\n",
        "While autograd is a great way to piece together fundamental operators, sometimes you want to create aggregate operators that do not perform autograd operations internally. Fusion is one common example of this where, for instance, you may want to generate a single CUDA kernel for `(a + b)*b`. TorchScript's symbolic autograd internally can separate this compute from autograd and generate explicit forward/backward passes. Let's look at what issues arise when trying to do this. This should help with understanding the PyTorch implementation, and also make it possible to create custom aggregate operators with correct autograd implementations. Let's try to turn our `simple` function from before into one that computes its entire body as an aggregate:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bqpAiClBJnM2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 443
        },
        "outputId": "713e7e56-0bd6-4677-f657-712c55fbe51f"
      },
      "source": [
        "def simple(a, b):\n",
        "    t = a + b\n",
        "    return t * b\n",
        "\n",
        "def simple_type_error(a, b):\n",
        "    t = (a.value + b.value)\n",
        "    r = Variable(t * b.value)\n",
        "    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
        "        # manually apply the chain rule to the compute,\n",
        "        # in practice a symbolic differentiator might create this code\n",
        "        dL_dr, = dL_doutputs\n",
        "        dr_dt = b # partial from: r = t * b\n",
        "        dr_db = t # partial from: r = t * b\n",
        "        dL_dt = dL_dr*dr_dt # chain rule\n",
        "        dt_da = 1.0 # partial from t = a + b\n",
        "        dt_db = 1.0 # partial from t = a + b\n",
        "        dL_da = dL_dt * dt_da # chain rule\n",
        "        dL_db = dL_dt * dt_db + dL_dr * dr_db # ERROR! dr_db is a Tensor not a Variable\n",
        "        return [dL_da, dL_db]\n",
        "    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))\n",
        "    return r\n",
        "\n",
        "da, db = run_gradients(simple_type_error)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = v2 * b\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "error",
          "ename": "AttributeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-33-443c996d15d3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     21\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_gradients\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msimple_type_error\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m<ipython-input-32-4a6e63ff86b6>\u001b[0m in \u001b[0;36mrun_gradients\u001b[0;34m(my_fn, second_loss)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0;31m# compute derivatives of our inputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mdL0_da\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdL0_db\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mL0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msecond_loss\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mdL0_da\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdL0_db\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-28-73c72df34bbd>\u001b[0m in \u001b[0;36mgrad\u001b[0;34m(L, desired_results)\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m         \u001b[0;31m# perform chain rule propagation specific to each compute\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m         \u001b[0mdL_dinputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mentry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpropagate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdL_doutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m         \u001b[0;31m# Accululate the gradient produced for each input.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-33-443c996d15d3>\u001b[0m in \u001b[0;36mpropagate\u001b[0;34m(dL_doutputs)\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mdt_db\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1.0\u001b[0m \u001b[0;31m# partial from t = a + b\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m         \u001b[0mdL_da\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdL_dt\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdt_da\u001b[0m \u001b[0;31m# chain rule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m         \u001b[0mdL_db\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdL_dt\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdt_db\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdL_dr\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdr_db\u001b[0m \u001b[0;31m# ERROR! dr_db is a Tensor not a Variable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mdL_da\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdL_db\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0mgradient_tape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTapeEntry\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpropagate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpropagate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-24-264419a9e1d2>\u001b[0m in \u001b[0;36m__mul__\u001b[0;34m(self, rhs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__mul__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Variable'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Variable'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m         \u001b[0;31m# defined later in the notebook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0moperator_mul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__add__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrhs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Variable'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;34m'Variable'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-27-802a827cdde8>\u001b[0m in \u001b[0;36moperator_mul\u001b[0;34m(self, rhs)\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0;31m# define forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m     \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mrhs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{r.name} = {self.name} * {rhs.name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'value'"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pa5P78Y4PHfT"
      },
      "source": [
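        "This doesn't work because `t` is a plain Tensor captured by `propagate`, but `propagate` expects to compute on Variables (here `dr_db = t` ends up as the right-hand side of `Variable.__mul__`). Because `t` was computed directly on the underlying tensors, outside of our autograd system, it cannot participate in the `propagate` call. One way to fix this is to recompute `t` from the Variables `a` and `b` inside `propagate`:"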
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gY9_KfgZPgql",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "f64d673a-7e95-42d6-b605-9ceec8efd6ae"
      },
      "source": [
        "def simple_recompute(a, b):\n",
        "    t = (a.value + b.value)\n",
        "    r = Variable(t * b.value)\n",
        "    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
        "        dL_dr, = dL_doutputs\n",
        "        dr_dt = b # partial from: r = t * b\n",
        "        t = a + b # RECOMPUTE!\n",
        "        dr_db = t # partial from: r = t * b\n",
        "        dL_dt = dL_dr*dr_dt # chain rule\n",
        "        dt_da = 1.0 # partial from t = a + b\n",
        "        dt_db = 1.0 # partial from t = a + b\n",
        "        dL_da = dL_dt * dt_da # chain rule\n",
        "        dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule\n",
        "        return [dL_da, dL_db]\n",
        "    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))\n",
        "    return r\n",
        "\n",
        "da, db = run_gradients(simple_recompute)\n",
        "da_ref, db_ref = run_gradients(simple)\n",
        "print(\"da\", da, da_ref)\n",
        "print(\"db\", db, db_ref)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = a + b\n",
            "v4 = v2 * b\n",
            "v5 = v2 * v3\n",
            "v6 = v4 + v5\n",
            "dL0_dL0 = v1\n",
            "dL0_dv0 = v2\n",
            "dL0_da = v4\n",
            "dL0_db = v6\n",
            "------------------------\n",
            "v7 = v4 * v4\n",
            "v8 = v6 * v6\n",
            "v9 = v7 + v8\n",
            "L1 = v9.sum()\n",
            "dL1 ------------------------\n",
            "v11 = v10.expand(4)\n",
            "v12 = v11 * v6\n",
            "v13 = v11 * v6\n",
            "v14 = v12 + v13\n",
            "v15 = v11 * v4\n",
            "v16 = v11 * v4\n",
            "v17 = v15 + v16\n",
            "v18 = v17 + v14\n",
            "v19 = v14 * v3\n",
            "v20 = v14 * v2\n",
            "v21 = v18 * b\n",
            "v22 = v18 * v2\n",
            "v23 = v19 + v21\n",
            "v24 = v22 + v20\n",
            "v25 = v23.sum()\n",
            "dL1_dL1 = v10\n",
            "dL1_dv9 = v11\n",
            "dL1_dv7 = v11\n",
            "dL1_dv8 = v11\n",
            "dL1_dv6 = v14\n",
            "dL1_dv4 = v18\n",
            "dL1_dv5 = v14\n",
            "dL1_dv2 = v23\n",
            "dL1_dv3 = v20\n",
            "dL1_db = v24\n",
            "dL1_da = v20\n",
            "dL1_dv1 = v25\n",
            "------------------------\n",
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "v0 = a + b\n",
            "v1 = v0 * b\n",
            "L0 = v1.sum()\n",
            "dL0 ------------------------\n",
            "v3 = v2.expand(4)\n",
            "v4 = v3 * b\n",
            "v5 = v3 * v0\n",
            "v6 = v5 + v4\n",
            "dL0_dL0 = v2\n",
            "dL0_dv1 = v3\n",
            "dL0_dv0 = v4\n",
            "dL0_db = v6\n",
            "dL0_da = v4\n",
            "------------------------\n",
            "v7 = v4 * v4\n",
            "v8 = v6 * v6\n",
            "v9 = v7 + v8\n",
            "L1 = v9.sum()\n",
            "dL1 ------------------------\n",
            "v11 = v10.expand(4)\n",
            "v12 = v11 * v6\n",
            "v13 = v11 * v6\n",
            "v14 = v12 + v13\n",
            "v15 = v11 * v4\n",
            "v16 = v11 * v4\n",
            "v17 = v15 + v16\n",
            "v18 = v17 + v14\n",
            "v19 = v14 * v0\n",
            "v20 = v14 * v3\n",
            "v21 = v18 * b\n",
            "v22 = v18 * v3\n",
            "v23 = v19 + v21\n",
            "v24 = v23.sum()\n",
            "v25 = v22 + v20\n",
            "dL1_dL1 = v10\n",
            "dL1_dv9 = v11\n",
            "dL1_dv7 = v11\n",
            "dL1_dv8 = v11\n",
            "dL1_dv6 = v14\n",
            "dL1_dv4 = v18\n",
            "dL1_dv5 = v14\n",
            "dL1_dv3 = v23\n",
            "dL1_dv0 = v20\n",
            "dL1_db = v25\n",
            "dL1_dv2 = v24\n",
            "dL1_da = v20\n",
            "------------------------\n",
            "da tensor([1.2611, 2.1304, 5.8394, 3.3869]) tensor([1.2611, 2.1304, 5.8394, 3.3869])\n",
            "db tensor([ 2.6923,  4.9201, 13.6563,  8.0727]) tensor([ 2.6923,  4.9201, 13.6563,  8.0727])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HLUaN1UeP4C_"
      },
      "source": [
        "This recompute works, but it is not ideal. First, the original computation may have been expensive (think of a sequence of convolutions and matrix multiplies), so redoing it in the backward pass can take significant time. Second, to recompute `t` we now have to save `a` and `b`, whereas previously we saved `t` and `b`. What if `a` were a _huge_ matrix but `t` were small? Then the recompute would use _more total memory_, not less. In general, we want to avoid recomputing values unless we know it will be cheap in both time and space.\n",
        "\n",
        "Let's consider another approach. What happens if we just make `t` into a Variable?"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZIysmiVCQaw5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 799
        },
        "outputId": "7b1a0d2a-a546-436f-af00-9b5d5cc53986"
      },
      "source": [
        "def simple_variable_wrong(a, b):\n",
        "    t = (a.value + b.value)\n",
        "    t_v = Variable(t, name='t') # named for debugging\n",
        "    r = Variable(t * b.value)\n",
        "    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
        "        dL_dr, = dL_doutputs\n",
        "        dr_dt = b # partial from: r = t * b\n",
        "        dr_db = t_v # partial from: r = t * b\n",
        "        dL_dt = dL_dr*dr_dt # chain rule\n",
        "        dt_da = 1.0 # partial from t = a + b\n",
        "        dt_db = 1.0 # partial from t = a + b\n",
        "        dL_da = dL_dt * dt_da # chain rule\n",
        "        dL_db = dL_dt * dt_db + dL_dr * dr_db # chain rule\n",
        "        return [dL_da, dL_db]\n",
        "    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name], propagate=propagate))\n",
        "    return r\n",
        "\n",
        "da, db = run_gradients(simple_variable_wrong)\n",
        "print(\"da\", da) # ERROR: da is None!!!????\n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = v2 * b\n",
            "v4 = v2 * t\n",
            "v5 = v3 + v4\n",
            "dL0_dL0 = v1\n",
            "dL0_dv0 = v2\n",
            "dL0_da = v3\n",
            "dL0_db = v5\n",
            "------------------------\n",
            "v6 = v3 * v3\n",
            "v7 = v5 * v5\n",
            "v8 = v6 + v7\n",
            "L1 = v8.sum()\n",
            "dL1 ------------------------\n",
            "v10 = v9.expand(4)\n",
            "v11 = v10 * v5\n",
            "v12 = v10 * v5\n",
            "v13 = v11 + v12\n",
            "v14 = v10 * v3\n",
            "v15 = v10 * v3\n",
            "v16 = v14 + v15\n",
            "v17 = v16 + v13\n",
            "v18 = v13 * t\n",
            "v19 = v13 * v2\n",
            "v20 = v17 * b\n",
            "v21 = v17 * v2\n",
            "v22 = v18 + v20\n",
            "v23 = v22.sum()\n",
            "dL1_dL1 = v9\n",
            "dL1_dv8 = v10\n",
            "dL1_dv6 = v10\n",
            "dL1_dv7 = v10\n",
            "dL1_dv5 = v13\n",
            "dL1_dv3 = v17\n",
            "dL1_dv4 = v13\n",
            "dL1_dv2 = v22\n",
            "dL1_dt = v19\n",
            "dL1_db = v21\n",
            "dL1_dv1 = v23\n",
            "------------------------\n",
            "da None\n",
            "db tensor([1.4312, 2.7896, 7.8169, 4.6857])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v6D1xMlfQ4xE"
      },
      "source": [
        "While we do not get an error, something is clearly wrong: `dL1/da` is `None`. We _know_ that the value of `a` affects the gradients of the original loss (and therefore `L1`, the sum of their squares), so this value should not be `None`. We are failing to propagate a gradient somewhere!\n",
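        "\n",
        "For reference, the traces above show that `L0 = ((a + b) * b).sum()` and that `L1` sums the squares of the first-order gradients. Working the derivatives out by hand (elementwise) gives:\n",
        "\n",
        "$$\\frac{\\partial L_0}{\\partial a} = b, \\qquad \\frac{\\partial L_0}{\\partial b} = a + 2b$$\n",
        "\n",
        "$$L_1 = \\sum \\left( b^2 + (a + 2b)^2 \\right)$$\n",
        "\n",
        "$$\\frac{\\partial L_1}{\\partial a} = 2(a + 2b), \\qquad \\frac{\\partial L_1}{\\partial b} = 2b + 4(a + 2b) = 4a + 10b$$\n",
        "\n",
        "So `dL1/da` is zero only where `a = -2b`. It is therefore non-zero for the positive inputs used here.\n",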
        "\n",
        "Let's see what happens when we run just the first gradient.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KsCpcoCYRVgR",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 544
        },
        "outputId": "c67a729e-0dfe-4e0e-b267-50d7ee05e307"
      },
      "source": [
        "da, db = run_gradients(simple_variable_wrong, second_loss=False)\n",
        "da_ref, db_ref = run_gradients(simple, second_loss=False)\n",
        "print(\"da\", da, da_ref) \n",
        "print(\"db\", db, db_ref)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = v2 * b\n",
            "v4 = v2 * t\n",
            "v5 = v3 + v4\n",
            "dL0_dL0 = v1\n",
            "dL0_dv0 = v2\n",
            "dL0_da = v3\n",
            "dL0_db = v5\n",
            "------------------------\n",
            "a = tensor([0.4605, 0.4061, 0.9422, 0.3946])\n",
            "b = tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "v0 = a + b\n",
            "v1 = v0 * b\n",
            "L0 = v1.sum()\n",
            "dL0 ------------------------\n",
            "v3 = v2.expand(4)\n",
            "v4 = v3 * b\n",
            "v5 = v3 * v0\n",
            "v6 = v5 + v4\n",
            "dL0_dL0 = v2\n",
            "dL0_dv1 = v3\n",
            "dL0_dv0 = v4\n",
            "dL0_db = v6\n",
            "dL0_da = v4\n",
            "------------------------\n",
            "da tensor([0.0850, 0.3296, 0.9888, 0.6494]) tensor([0.0850, 0.3296, 0.9888, 0.6494])\n",
            "db tensor([0.6306, 1.0652, 2.9197, 1.6935]) tensor([0.6306, 1.0652, 2.9197, 1.6935])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ck3a9nQRil7"
      },
      "source": [
        "In the single-backward case, we get the right answer! This illustrates a key pitfall when writing autograd code: it is _very easy_ to write code that appears to work for a single backward pass but breaks when computing higher-order gradients.\n",
        "\n",
        "So what is going wrong? Look at the debug trace from the first time we ran `simple_variable_wrong`. Inside the computation of `dL0` (the first backward), there is a line `v4 = v2 * t`: the first backward _uses_ the value of `t`. Whenever a computation uses `t`, a future loss (`L1`) that depends on the results of that computation can have a non-zero gradient `dL1/dt`. But `simple_variable_wrong` never accounts for this future use of `t`! It includes the pathway from `t` through `r` (`dL_dt = dL_dr*dr_dt`), but ignores uses of `t` outside the local aggregate. Such uses are subtle because `t` escapes from our computation _only_ as a closed-over variable in `propagate`. As a result, this gradient pathway can be non-zero only for higher-order gradients that differentiate through that use. The trace of the second backward shows this directly: `dL1_dt = v19` is computed, but no tape entry lists `t` as an output, so that gradient never propagates back to `a` or `b`.\n",
        "\n",
        "The problem arises because `t` was not declared as an output of the original computation, even though the computation defined it and later computations used it. We can fix this by declaring it as an output in the gradient tape and adding the gradient contribution that flows back through it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2cbRekShbmxX",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 986
        },
        "outputId": "950d5943-f63f-4ff3-e370-72e9316d340b"
      },
      "source": [
        "def simple_variable_almost(a, b):\n",
        "    t = (a.value + b.value)\n",
        "    t_v = Variable(t, name='t_v')\n",
        "    r = Variable(t * b.value)\n",
        "    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
        "        # t is considered an output, so we now get dL_dt0 as an input.\n",
        "        dL_dr, dL_dt0 = dL_doutputs\n",
        "               ###### new gradient contribution\n",
        "\n",
        "        # Handle cases where one incoming gradient is zero (None)\n",
        "        if dL_dr is None:\n",
        "          dL_dr = Variable.constant(torch.zeros(()))\n",
        "        if dL_dt0 is None:\n",
        "          dL_dt0 = Variable.constant(torch.zeros(()))\n",
        "               \n",
        "\n",
        "        dr_dt = b \n",
        "        dr_db = t_v \n",
        "        # we combine this with the contribution from r to calculate \n",
        "        # all gradient paths to dL_dt\n",
        "        dL_dt = dL_dt0 + dL_dr*dr_dt # chain rule\n",
        "                ######\n",
        "\n",
        "        dt_da = 1.0 \n",
        "        dt_db = 1.0 \n",
        "        dL_db = dL_dr * dr_db + dL_dt * dt_db \n",
        "        dL_da = dL_dt * dt_da\n",
        "        return [dL_da, dL_db]\n",
        "\n",
        "    # note: t_v is now considered an output in the tape\n",
        "    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))\n",
        "                                                                             ######### new output\n",
        "    return r\n",
        "da, db = run_gradients(simple_variable_almost)\n",
        "print(\"da\", da) \n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
            "b = tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = 0.0\n",
            "v4 = v2 * b\n",
            "v5 = v3 + v4\n",
            "v6 = v2 * t_v\n",
            "v7 = v6 + v5\n",
            "dL0_dL0 = v1\n",
            "dL0_dv0 = v2\n",
            "dL0_da = v5\n",
            "dL0_db = v7\n",
            "------------------------\n",
            "v8 = v5 * v5\n",
            "v9 = v7 * v7\n",
            "v10 = v8 + v9\n",
            "L1 = v10.sum()\n",
            "dL1 ------------------------\n",
            "v12 = v11.expand(4)\n",
            "v13 = v12 * v7\n",
            "v14 = v12 * v7\n",
            "v15 = v13 + v14\n",
            "v16 = v12 * v5\n",
            "v17 = v12 * v5\n",
            "v18 = v16 + v17\n",
            "v19 = v18 + v15\n",
            "v20 = v15 * t_v\n",
            "v21 = v15 * v2\n",
            "v22 = v19 * b\n",
            "v23 = v19 * v2\n",
            "v24 = v20 + v22\n",
            "v25 = v24.sum()\n",
            "v26 = 0.0\n",
            "v27 = v26 * b\n",
            "v28 = v21 + v27\n",
            "v29 = v26 * t_v\n",
            "v30 = v29 + v28\n",
            "v31 = v23 + v30\n",
            "dL1_dL1 = v11\n",
            "dL1_dv10 = v12\n",
            "dL1_dv8 = v12\n",
            "dL1_dv9 = v12\n",
            "dL1_dv7 = v15\n",
            "dL1_dv5 = v19\n",
            "dL1_dv6 = v15\n",
            "dL1_dv2 = v24\n",
            "dL1_dt_v = v21\n",
            "dL1_dv3 = v19\n",
            "dL1_dv4 = v19\n",
            "dL1_db = v31\n",
            "dL1_dv1 = v25\n",
            "dL1_da = v28\n",
            "------------------------\n",
            "da tensor([1.5438, 2.8499, 3.2622, 1.3134])\n",
            "db tensor([3.8424, 6.9614, 7.5721, 2.9042])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GjIEL9nncdmO"
      },
      "source": [
        "This code is now correct! However, it has some non-optimal behavior. Notice how at the beginning of `propagate` we need to handle the cases where the incoming gradients are `None`. Recall that when a pathway has no gradient, we give it the value `None`. The first time through `propagate`, `dL_dt0` will be `None`, since `t` is not used outside of the `propagate` function itself during the first backward. The _second_ time through `propagate`, `dL_dt0` will have a value but `dL_dr` will be `None`. Exercise: work out why `dL_dr` is `None` the second time through. Turning the `None` values into zeros gives the right answer, but at the cost of always doing more compute. For instance, in this case every single-backward call performs an extra pointwise addition of a zero tensor to handle the `dL_dt0` input (`v5 = v3 + v4` in the trace above, where `v3 = 0.0`).\n",
        "\n",
        "It is worth paying for a constant-time check for zero to avoid a tensor-sized amount of work. So we optimize this code by replicating some of the `None`-handling logic from `grad` directly in the aggregate op. It is a little messy, but it handles the cases where inputs might be `None` with a minimal amount of compute."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fshJIV4xcJKW",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 833
        },
        "outputId": "90eee555-de63-4fa5-fdf1-313d9fae8116"
      },
      "source": [
        "def add_optional(a: Optional['Variable'], b: Optional['Variable']):\n",
        "    if a is None:\n",
        "        return b\n",
        "    if b is None:\n",
        "        return a\n",
        "    return a + b\n",
        "\n",
        "def simple_variable(a, b):\n",
        "    t = (a.value + b.value)\n",
        "    t_v = Variable(t, name='t_v')\n",
        "    r = Variable(t * b.value)\n",
        "    def propagate(dL_doutputs: List[Variable]) -> List[Variable]:\n",
        "        dL_dr, dL_dt0 = dL_doutputs\n",
        "        dr_dt = b # partial from: r = t * b\n",
        "        dr_db = t_v # partial from: r = t * b\n",
        "        dL_dt = dL_dt0\n",
        "        if dL_dr is not None:\n",
        "            dL_dt = add_optional(dL_dt, dL_dr*dr_dt) # chain rule\n",
        "\n",
        "        dt_da = 1.0 # partial from t = a + b\n",
        "        dt_db = 1.0 # partial from t = a + b\n",
        "        if dL_dr is not None:\n",
        "            dL_db = dL_dr * dr_db # chain rule\n",
        "        else:\n",
        "            dL_db = None\n",
        "\n",
        "        if dL_dt is not None:\n",
        "            dL_da = dL_dt * dt_da # chain rule\n",
        "            dL_db = add_optional(dL_db, dL_dt * dt_db)\n",
        "        else:\n",
        "            dL_da = None\n",
        "\n",
        "        return [dL_da, dL_db]\n",
        "\n",
        "    gradient_tape.append(TapeEntry(inputs=[a.name, b.name], outputs=[r.name, t_v.name], propagate=propagate))\n",
        "    return r\n",
        "da, db = run_gradients(simple_variable)\n",
        "print(\"da\", da) \n",
        "print(\"db\", db)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "a = tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
            "b = tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
            "L0 = v0.sum()\n",
            "dL0 ------------------------\n",
            "v2 = v1.expand(4)\n",
            "v3 = v2 * b\n",
            "v4 = v2 * t_v\n",
            "v5 = v4 + v3\n",
            "dL0_dL0 = v1\n",
            "dL0_dv0 = v2\n",
            "dL0_da = v3\n",
            "dL0_db = v5\n",
            "------------------------\n",
            "v6 = v3 * v3\n",
            "v7 = v5 * v5\n",
            "v8 = v6 + v7\n",
            "L1 = v8.sum()\n",
            "dL1 ------------------------\n",
            "v10 = v9.expand(4)\n",
            "v11 = v10 * v5\n",
            "v12 = v10 * v5\n",
            "v13 = v11 + v12\n",
            "v14 = v10 * v3\n",
            "v15 = v10 * v3\n",
            "v16 = v14 + v15\n",
            "v17 = v16 + v13\n",
            "v18 = v13 * t_v\n",
            "v19 = v13 * v2\n",
            "v20 = v17 * b\n",
            "v21 = v17 * v2\n",
            "v22 = v18 + v20\n",
            "v23 = v22.sum()\n",
            "v24 = v21 + v19\n",
            "dL1_dL1 = v9\n",
            "dL1_dv8 = v10\n",
            "dL1_dv6 = v10\n",
            "dL1_dv7 = v10\n",
            "dL1_dv5 = v13\n",
            "dL1_dv3 = v17\n",
            "dL1_dv4 = v13\n",
            "dL1_dv2 = v22\n",
            "dL1_dt_v = v19\n",
            "dL1_db = v24\n",
            "dL1_dv1 = v23\n",
            "dL1_da = v19\n",
            "------------------------\n",
            "da tensor([1.5438, 2.8499, 3.2622, 1.3134])\n",
            "db tensor([3.8424, 6.9614, 7.5721, 2.9042])\n"
          ],
          "name": "stdout"
        }
      ]
    },
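    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an independent cross-check, the cell below computes the same second-order gradients with PyTorch's built-in autograd. It is a minimal sketch and not part of Simple Grad. It assumes, as the traces above show, that `simple` computes `L0 = ((a + b) * b).sum()` and that `run_gradients` uses `L1 = (dL0_da * dL0_da + dL0_db * dL0_db).sum()`. Passing `create_graph=True` to the first `torch.autograd.grad` call records the backward pass itself, so the second call can differentiate through it. The inputs are the rounded `a` and `b` printed above, so the results should agree with the printed `da` and `db` to about three decimal places. The cell also checks the closed forms `dL1/da = 2(a + 2b)` and `dL1/db = 4a + 10b`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "\n",
        "def reference_gradients(a: torch.Tensor, b: torch.Tensor):\n",
        "    # Fresh leaf tensors so PyTorch's own autograd tracks them (Simple Grad does not use it).\n",
        "    a = a.clone().requires_grad_(True)\n",
        "    b = b.clone().requires_grad_(True)\n",
        "    L0 = ((a + b) * b).sum()\n",
        "    # create_graph=True makes the first backward pass itself differentiable.\n",
        "    dL0_da, dL0_db = torch.autograd.grad(L0, [a, b], create_graph=True)\n",
        "    L1 = (dL0_da * dL0_da + dL0_db * dL0_db).sum()\n",
        "    dL1_da, dL1_db = torch.autograd.grad(L1, [a, b])\n",
        "    return dL1_da, dL1_db\n",
        "\n",
        "a_ref = torch.tensor([0.0171, 0.1633, 0.5833, 0.3794])\n",
        "b_ref = torch.tensor([0.3774, 0.6308, 0.5239, 0.1387])\n",
        "da_torch, db_torch = reference_gradients(a_ref, b_ref)\n",
        "print('da', da_torch)\n",
        "print('db', db_torch)\n",
        "# Closed form for this example: dL1/da = 2(a + 2b), dL1/db = 4a + 10b\n",
        "print(torch.allclose(da_torch, 2 * (a_ref + 2 * b_ref)),\n",
        "      torch.allclose(db_torch, 4 * a_ref + 10 * b_ref))"
      ],
      "execution_count": null,
      "outputs": []
    },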
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r77nkn7Y5NjU"
      },
      "source": [
        "**Exercise:** modify `run_gradients` so that the second call to `grad` produces non-zero values for both `dL_dr` and `dL_dt`. Hint: it can be done by adding 2 characters."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lZvYIrBS5fXr"
      },
      "source": [
        "In PyTorch's symbolic autodiff implementation, zero gradients are represented by undefined tensors in place of `None` values, but the logic in TorchScript is very similar. The function `any_defined(...)` checks whether any of its inputs are defined (that is, possibly non-zero) and guards the calculation of unused parts of the backward pass. The `AutogradAdd(a, b)` operator adds two tensors, handling the case where either is undefined, similar to `add_optional`.\n",
        "\n",
        "All of this conditional logic makes the backward pass very messy. Furthermore, as you have seen in these examples, the branches often go the same direction every time. This is especially true for single-backward, where the gradients flowing into captured temporaries are always zero. It is therefore profitable to specialize this code for the particular patterns of non-zero gradients that actually occur, since straight-line code allows more aggressive fusion of the backward pass."
      ]
    },
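    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell below is a minimal sketch of these guards on plain tensors. It uses no `Variable`, so it supports only a first-order backward. `any_defined` and `autograd_add` are hypothetical Python stand-ins that mirror the roles described above, with `None` playing the part of an undefined tensor. They are not TorchScript's actual implementation. The backward for `simple` (with `r = t * b`, `t = a + b`) first skips all tensor work when no incoming gradient is defined. It then uses `autograd_add` wherever two optional contributions meet. In a single backward, `dL_dt0` is always `None`, so every call takes the same branches. That predictability is what specialization exploits. For `a = [1, 2]` and `b = [3, 4]`, the expected result is `dL0/da = b = [3, 4]` and `dL0/db = a + 2b = [7, 10]`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "from typing import List, Optional\n",
        "\n",
        "# Hypothetical stand-ins: None represents an undefined (all-zero) gradient.\n",
        "def any_defined(*grads: Optional[torch.Tensor]) -> bool:\n",
        "    return any(g is not None for g in grads)\n",
        "\n",
        "def autograd_add(x: Optional[torch.Tensor], y: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n",
        "    if x is None:\n",
        "        return y\n",
        "    if y is None:\n",
        "        return x\n",
        "    return x + y\n",
        "\n",
        "def simple_backward(dL_dr: Optional[torch.Tensor], dL_dt0: Optional[torch.Tensor],\n",
        "                    b: torch.Tensor, t: torch.Tensor) -> List[Optional[torch.Tensor]]:\n",
        "    # Backward of r = t * b, t = a + b, computed on plain tensors.\n",
        "    if not any_defined(dL_dr, dL_dt0):\n",
        "        return [None, None]  # constant-time check skips all tensor-sized work\n",
        "    dL_dt, dL_db = dL_dt0, None\n",
        "    if dL_dr is not None:\n",
        "        dL_dt = autograd_add(dL_dt, dL_dr * b)  # dr/dt = b\n",
        "        dL_db = dL_dr * t                       # dr/db = t\n",
        "    # At least one incoming gradient was defined, so dL_dt is defined here.\n",
        "    dL_da = dL_dt                               # dt/da = 1\n",
        "    dL_db = autograd_add(dL_db, dL_dt)          # dt/db = 1\n",
        "    return [dL_da, dL_db]\n",
        "\n",
        "a_demo = torch.tensor([1.0, 2.0])\n",
        "b_demo = torch.tensor([3.0, 4.0])\n",
        "t_demo = a_demo + b_demo\n",
        "# Single-backward pattern: only dL_dr is defined.\n",
        "print(simple_backward(torch.ones(2), None, b_demo, t_demo))\n",
        "# No defined gradients: nothing is computed.\n",
        "print(simple_backward(None, None, b_demo, t_demo))"
      ],
      "execution_count": null,
      "outputs": []
    },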
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sAApdEb27KTA"
      },
      "source": [
        "# PyTorch vs Simple Grad\n",
        "\n",
        "Simple Grad gives a good overview of how PyTorch's autograd works. TorchScript's symbolic gradient pass can generate aggregate operators from subsets of the IR by automating the process we went through to define `simple` as an aggregate operator.\n",
        "\n",
        "The real PyTorch autograd has some features beyond this example that relate to mutable tensors. Simple Grad assumes that tensors are immutable, so saving a Tensor for use in `propagate` is as simple as saving a reference to it. In PyTorch, the gradient formulas need to explicitly mark a Tensor as saved so that future mutations can be tracked: if a user mutates a tensor that backward needs, autograd reports an error when the saved tensor is used. Mutable ops also affect how the `propagate` functions get recorded. If a tensor is mutated, uses of the tensor _before_ the mutation need to propagate gradient to the original value, while uses _after_ the mutation propagate gradient to the new, mutated value. Since tensors can be views of other mutable tensors, PyTorch needs bookkeeping to ensure that whenever a tensor is updated, all views of it propagate gradient to the new value and not the old one.\n",
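        "\n",
        "Here is a minimal sketch of this check using PyTorch's real autograd. `exp` reuses its own output in its backward formula, so autograd saves that output. Mutating it in place before calling `backward()` raises a `RuntimeError` rather than silently producing a wrong gradient. The exact message varies between PyTorch versions.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "x = torch.ones(3, requires_grad=True)\n",
        "y = x.exp()   # exp's backward needs its output y, so y is saved for backward\n",
        "y.add_(1)     # in-place update of a tensor that was saved for backward\n",
        "raised = False\n",
        "try:\n",
        "    y.sum().backward()\n",
        "except RuntimeError as err:\n",
        "    raised = True  # autograd detected that the saved tensor was modified\n",
        "    print('RuntimeError:', err)\n",
        "print('mutation detected:', raised)\n",
        "```\n",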
        "\n",
        "# Where to go from here\n",
        "\n",
        "If you still have questions about how this process works, I encourage you to edit this notebook to add more debug information and experiment with the computations. You can try:\n",
        "* Adding a new operator with a `propagate` formula (use `torch.autograd.grad` to verify correctness).\n",
        "* Modifying `run_gradients` to calculate weirder higher-order gradients and checking that it behaves as you expect.\n",
        "* Removing `None` and implementing gradients using Tensor zeros.\n",
        "* Manually defining another aggregate operator for something similar to `simple`.\n",
        "* Writing a 'compiler' that takes a small expression similar to `simple` and automatically transforms it into a forward function and a `propagate` function, similar to `autodiff.cpp`.\n",
        "* Rewriting `simple_variable` so that all the branching for `None` checks is at the top of `propagate`. Can you generalize this so that a compiler can generate specializations for the non-zero patterns it sees?\n",
        "* Reading `autodiff.cpp` and adding a description to this document of how the concepts here relate directly to that code."
      ]
    }
  ]
}